{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#skip\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from fastai.torch_basics import *\n",
    "from fastai.tabular.core import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#default_exp tabular.model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tabular model\n",
    "\n",
    "> A basic model that can be used on tabular data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def emb_sz_rule(n_cat):\n",
    "    \"Heuristic embedding width for a categorical variable with `n_cat` levels, capped at 600\"\n",
    "    suggested = round(1.6 * n_cat**0.56)\n",
    "    return suggested if suggested < 600 else 600"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def _one_emb_sz(classes, n, sz_dict=None):\n",
    "    \"Return `(cardinality, embedding size)` for column `n`; an entry `sz_dict[n]` overrides the rule-of-thumb size.\"\n",
    "    overrides = ifnone(sz_dict, {})\n",
    "    cardinality = len(classes[n])\n",
    "    default_sz = int(emb_sz_rule(cardinality))  # heuristic fallback, used when `n` has no override\n",
    "    return cardinality, overrides.get(n, default_sz)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Through trial and error, this general rule takes the lower of two values:\n",
    "* A dimension space of 600\n",
    "* A dimension space equal to 1.6 times the cardinality of the variable raised to the power of 0.56, rounded to the nearest integer.\n",
    "\n",
    "This provides a good starting point for the embedding size of your variables. More advanced users who wish to go further can tweak these values at their discretion; it is not uncommon for slight adjustments to this general formula to give better results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def get_emb_sz(to, sz_dict=None):\n",
    "    \"List of `(n_cat, emb_sz)` pairs, one per name in `to.cat_names`, taking sizes from `sz_dict` where given and the rule of thumb otherwise\"\n",
    "    emb_szs = []\n",
    "    for cat_name in to.cat_names:\n",
    "        emb_szs.append(_one_emb_sz(to.classes, cat_name, sz_dict))\n",
    "    return emb_szs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class TabularModel(Module):\n",
    "    \"Basic model for tabular data: embedded `x_cat` and (optionally batch-normed) `x_cont` are concatenated and fed to a stack of `LinBnDrop`.\"\n",
    "    def __init__(self, emb_szs, n_cont, out_sz, layers, ps=None, embed_p=0.,\n",
    "                 y_range=None, use_bn=True, bn_final=False, bn_cont=True, act_cls=nn.ReLU(inplace=True),\n",
    "                 lin_first=True):\n",
    "        # `emb_szs`: `(n_categories, emb_dim)` pairs, one `Embedding` is built per pair\n",
    "        # `layers`: hidden layer widths; `ps`: dropout prob(s) for those layers (a scalar is repeated for each)\n",
    "        # `y_range`: when not None, a final `SigmoidRange(*y_range)` is appended\n",
    "        layers = list(layers)  # accept tuples too: the list concatenation below needs a real list\n",
    "        ps = ifnone(ps, [0]*len(layers))\n",
    "        ps = list(ps) if is_listy(ps) else [ps]*len(layers)\n",
    "        if len(ps) < len(layers):\n",
    "            # the `zip` below would otherwise silently drop the trailing layers, including the output layer\n",
    "            raise ValueError(f\"`ps` has {len(ps)} entries but `layers` has {len(layers)}\")\n",
    "        # NOTE(review): a `ps` longer than `layers` is still accepted as before; its entry at index\n",
    "        # `len(layers)` then becomes the dropout of the output layer and later entries are ignored\n",
    "        self.embeds = nn.ModuleList([Embedding(ni, nf) for ni,nf in emb_szs])\n",
    "        self.emb_drop = nn.Dropout(embed_p)\n",
    "        self.bn_cont = nn.BatchNorm1d(n_cont) if bn_cont else None\n",
    "        n_emb = sum(e.embedding_dim for e in self.embeds)\n",
    "        self.n_emb,self.n_cont = n_emb,n_cont\n",
    "        sizes = [n_emb + n_cont] + layers + [out_sz]\n",
    "        # every hidden layer gets the activation; the output layer gets none\n",
    "        actns = [act_cls for _ in range(len(sizes)-2)] + [None]\n",
    "        # the output layer (last index) gets dropout prob 0. and batchnorm only when `bn_final` is set\n",
    "        _layers = [LinBnDrop(sizes[i], sizes[i+1], bn=use_bn and (i!=len(actns)-1 or bn_final), p=p, act=a, lin_first=lin_first)\n",
    "                       for i,(p,a) in enumerate(zip(ps+[0.],actns))]\n",
    "        if y_range is not None: _layers.append(SigmoidRange(*y_range))\n",
    "        self.layers = nn.Sequential(*_layers)\n",
    "\n",
    "    def forward(self, x_cat, x_cont=None):\n",
    "        # column `i` of `x_cat` feeds the i-th embedding; `x_cont` is only read when `n_cont != 0`\n",
    "        # NOTE(review): with no embeddings and `n_cont == 0`, `x` is never assigned before `self.layers(x)`\n",
    "        if self.n_emb != 0:\n",
    "            x = [e(x_cat[:,i]) for i,e in enumerate(self.embeds)]\n",
    "            x = torch.cat(x, 1)\n",
    "            x = self.emb_drop(x)\n",
    "        if self.n_cont != 0:\n",
    "            if self.bn_cont is not None: x_cont = self.bn_cont(x_cont)\n",
    "            x = torch.cat([x, x_cont], 1) if self.n_emb != 0 else x_cont\n",
    "        return self.layers(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This model expects your `cat` and `cont` variables separated. `cat` is passed through an `Embedding` layer and potential `Dropout`, while `cont` is passed through potential `BatchNorm1d`. Afterwards both are concatenated and passed through a series of `LinBnDrop`, before a final `Linear` layer corresponding to the expected outputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "emb_szs = [(4,2), (17,8)]\n",
    "m = TabularModel(emb_szs, n_cont=2, out_sz=2, layers=[200,100]).eval()\n",
    "x_cat = torch.tensor([[2,12]]).long()\n",
    "x_cont = torch.tensor([[0.7633, -0.1887]]).float()\n",
    "out = m(x_cat, x_cont)\n",
    "assert out.shape == (1, 2)  # one input row, `out_sz=2` outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@delegates(TabularModel.__init__)\n",
    "def tabular_config(**kwargs):\n",
    "    \"Gather keyword arguments for `TabularModel` into a config dict\"\n",
    "    config = kwargs\n",
    "    return config"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Any direct setup of `TabularModel`'s internals should be passed through here:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'embed_p': 0.6, 'use_bn': False}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "config = tabular_config(embed_p=0.6, use_bn=False)\n",
    "config"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 00_torch_core.ipynb.\n",
      "Converted 01_layers.ipynb.\n",
      "Converted 01a_losses.ipynb.\n",
      "Converted 02_data.load.ipynb.\n",
      "Converted 03_data.core.ipynb.\n",
      "Converted 04_data.external.ipynb.\n",
      "Converted 05_data.transforms.ipynb.\n",
      "Converted 06_data.block.ipynb.\n",
      "Converted 07_vision.core.ipynb.\n",
      "Converted 08_vision.data.ipynb.\n",
      "Converted 09_vision.augment.ipynb.\n",
      "Converted 09b_vision.utils.ipynb.\n",
      "Converted 09c_vision.widgets.ipynb.\n",
      "Converted 10_tutorial.pets.ipynb.\n",
      "Converted 10b_tutorial.albumentations.ipynb.\n",
      "Converted 11_vision.models.xresnet.ipynb.\n",
      "Converted 12_optimizer.ipynb.\n",
      "Converted 13_callback.core.ipynb.\n",
      "Converted 13a_learner.ipynb.\n",
      "Converted 13b_metrics.ipynb.\n",
      "Converted 14_callback.schedule.ipynb.\n",
      "Converted 14a_callback.data.ipynb.\n",
      "Converted 15_callback.hook.ipynb.\n",
      "Converted 15a_vision.models.unet.ipynb.\n",
      "Converted 16_callback.progress.ipynb.\n",
      "Converted 17_callback.tracker.ipynb.\n",
      "Converted 18_callback.fp16.ipynb.\n",
      "Converted 18a_callback.training.ipynb.\n",
      "Converted 18b_callback.preds.ipynb.\n",
      "Converted 19_callback.mixup.ipynb.\n",
      "Converted 20_interpret.ipynb.\n",
      "Converted 20a_distributed.ipynb.\n",
      "Converted 21_vision.learner.ipynb.\n",
      "Converted 22_tutorial.imagenette.ipynb.\n",
      "Converted 23_tutorial.vision.ipynb.\n",
      "Converted 24_tutorial.siamese.ipynb.\n",
      "Converted 24_vision.gan.ipynb.\n",
      "Converted 30_text.core.ipynb.\n",
      "Converted 31_text.data.ipynb.\n",
      "Converted 32_text.models.awdlstm.ipynb.\n",
      "Converted 33_text.models.core.ipynb.\n",
      "Converted 34_callback.rnn.ipynb.\n",
      "Converted 35_tutorial.wikitext.ipynb.\n",
      "Converted 36_text.models.qrnn.ipynb.\n",
      "Converted 37_text.learner.ipynb.\n",
      "Converted 38_tutorial.text.ipynb.\n",
      "Converted 39_tutorial.transformers.ipynb.\n",
      "Converted 40_tabular.core.ipynb.\n",
      "Converted 41_tabular.data.ipynb.\n",
      "Converted 42_tabular.model.ipynb.\n",
      "Converted 43_tabular.learner.ipynb.\n",
      "Converted 44_tutorial.tabular.ipynb.\n",
      "Converted 45_collab.ipynb.\n",
      "Converted 46_tutorial.collab.ipynb.\n",
      "Converted 50_tutorial.datablock.ipynb.\n",
      "Converted 60_medical.imaging.ipynb.\n",
      "Converted 61_tutorial.medical_imaging.ipynb.\n",
      "Converted 65_medical.text.ipynb.\n",
      "Converted 70_callback.wandb.ipynb.\n",
      "Converted 71_callback.tensorboard.ipynb.\n",
      "Converted 72_callback.neptune.ipynb.\n",
      "Converted 73_callback.captum.ipynb.\n",
      "Converted 74_callback.azureml.ipynb.\n",
      "Converted 97_test_utils.ipynb.\n",
      "Converted 99_pytorch_doc.ipynb.\n",
      "Converted dev-setup.ipynb.\n",
      "Converted index.ipynb.\n",
      "Converted quick_start.ipynb.\n",
      "Converted tutorial.ipynb.\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "from nbdev.export import notebook2script\n",
    "notebook2script()"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
